
Add with_sharding_constraint operator #297

Open

fmassa wants to merge 2 commits into main from fmassa/with_sharding_constraint

Conversation

@fmassa (Contributor) commented Jan 21, 2026

Summary

  • Adds a with_sharding_constraint operator, similar to JAX's jax.lax.with_sharding_constraint, for constraining intermediate tensor shardings
  • Uses local_map internally with an identity function for a simple implementation
  • Exports the new operator from the public API

Implementation Details

The with_sharding_constraint(x, shardings, device_mesh=None) function (see the sketch after this list):

  • Takes a tensor, desired placements, and optional device mesh
  • Uses local_map with matching in_placements and out_placements to enforce the constraint
  • Falls back to global mesh from enclosing local_map region if device_mesh is not provided
  • Uses clone() inside the identity function to avoid input-to-output aliasing issues with dynamo's HOP tracing
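
A minimal sketch of this approach, assuming the torch.distributed.tensor.experimental.local_map API; this is illustrative, not the exact code in the PR:

# Illustrative sketch only, not the PR's actual implementation.
from torch.distributed.tensor.experimental import local_map

def with_sharding_constraint(x, shardings, device_mesh=None):
    # clone() breaks input-to-output aliasing, which dynamo's HOP
    # tracing does not handle.
    def identity(t):
        return t.clone()

    constrained = local_map(
        identity,
        out_placements=shardings,    # placements for the single output
        in_placements=(shardings,),  # one placement tuple per input
        device_mesh=device_mesh,     # None: fall back to the enclosing mesh
    )
    return constrained(x)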

Example Usage

from autoparallel import with_sharding_constraint
from torch.distributed.tensor.placement_types import Shard, Replicate

# Constrain an intermediate tensor to be sharded along dim 0 of the
# first mesh dimension and replicated along the second
x = with_sharding_constraint(x, (Shard(0), Replicate()), device_mesh)
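
Here device_mesh is assumed to be a 2D DeviceMesh created elsewhere; a hedged setup sketch (the device type, world size, and dim names are illustrative, not from the PR):

from torch.distributed.device_mesh import init_device_mesh

# Hypothetical 2D mesh over 4 ranks (2 x 2); shape and names are
# placeholders for whatever the training setup actually uses.
device_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))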

Test plan

  • Test with explicit device mesh
  • Test between local_map regions
  • Test forcing replication
  • Test with 2D mesh
  • Test multiple constraints in sequence
  • Test error when no mesh is available
  • Verify placements match requested constraints in optimizer output (a sketch of this check follows)
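
A minimal sketch of that placement check, assuming x is already a DTensor on mesh; the names and structure are assumptions, not the PR's actual tests:

from torch.distributed.tensor import Shard

# Hypothetical assertion: the constrained output should carry the
# requested placements.
y = with_sharding_constraint(x, (Shard(0),), mesh)
assert tuple(y.placements) == (Shard(0),)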

@fmassa requested review from wconstab and xmfan, Jan 21, 2026 14:11
@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot), Jan 21, 2026
def forward(self, x):
    x = self.linear1(x)
    # Constrain the intermediate result to be sharded
    x = with_sharding_constraint(x, (Shard(0),), device_mesh_1d)
A contributor commented:
one thought: if we used `x.redistribute(Shard(0))`, which is valid DTensor code, as a way for autop to infer constraints, would that be a way to avoid having autop vs. normal DTensor code diverge?
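
For reference, the DTensor-native form suggested here would look roughly like this (a sketch, assuming x is already a DTensor; DTensor.redistribute also accepts an explicit mesh as its first argument):

# Hypothetical DTensor-native equivalent of the constraint above.
x = x.redistribute(placements=(Shard(0),))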

@wconstab (Contributor) left a comment:

seems OK to me, but I do have a question about whether we'd want to introduce autop-specific code into the model code.

